package geektime.spring.springbucks.filter;

import com.alibaba.druid.filter.FilterEventAdapter;
import com.alibaba.druid.proxy.jdbc.StatementProxy;
import com.alibaba.druid.sql.dialect.h2.parser.H2StatementParser;
import com.alibaba.druid.sql.parser.Lexer;
import com.alibaba.druid.sql.parser.Token;
import geektime.spring.springbucks.exception.InterceptException;
import geektime.spring.springbucks.model.TwoTuple;
import lombok.SneakyThrows;

import java.util.ArrayDeque;
import java.util.Comparator;
import java.util.Deque;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Stack;

/**
 * Druid filter that vetoes two kinds of risky SQL before execution:
 * <ul>
 *   <li>text containing more than one statement joined by {@code ;} (statement concatenation);</li>
 *   <li>{@code IN (...)} value lists with more elements than a configurable limit.</li>
 * </ul>
 * <p>Both checks work on the token stream produced by Druid's H2 lexer, so semicolons and commas
 * inside string literals are not miscounted. The checks run in the {@code statementExecuteBefore},
 * {@code statementExecuteQueryBefore} and {@code statementExecuteUpdateBefore} hooks; a violation
 * is reported by throwing {@link InterceptException}.
 */
public class InterceptFilter extends FilterEventAdapter {

    /** Default upper bound on the number of elements in a single {@code IN (...)} list. */
    public static final int DEFAULT_MAX_IN_LIST_SIZE = 10;

    /** Upper bound on the number of elements in a single {@code IN (...)} list. */
    private final int maxInListSize;

    /** Creates a filter that allows at most {@value #DEFAULT_MAX_IN_LIST_SIZE} IN-list elements. */
    public InterceptFilter() {
        this(DEFAULT_MAX_IN_LIST_SIZE);
    }

    /**
     * Creates a filter with a custom IN-list limit.
     *
     * @param maxInListSize largest number of elements allowed in one {@code IN (...)} list
     * @throws IllegalArgumentException if {@code maxInListSize} is less than 1
     */
    public InterceptFilter(int maxInListSize) {
        if (maxInListSize < 1) {
            throw new IllegalArgumentException("maxInListSize must be >= 1, got " + maxInListSize);
        }
        this.maxInListSize = maxInListSize;
    }

    @Override
    protected void statementExecuteBefore(StatementProxy statement, String sql) {
        intercept(sql);
        super.statementExecuteBefore(statement, sql);
    }

    @Override
    protected void statementExecuteQueryBefore(StatementProxy statement, String sql) {
        intercept(sql);
        super.statementExecuteQueryBefore(statement, sql);
    }

    @Override
    protected void statementExecuteUpdateBefore(StatementProxy statement, String sql) {
        intercept(sql);
        super.statementExecuteUpdateBefore(statement, sql);
    }

    /**
     * Rejects {@code sql} if it concatenates several statements or contains an oversized IN list.
     * The concatenation check runs first, so SQL violating both rules reports concatenation.
     *
     * @param sql SQL text about to be executed; {@code null} is ignored
     * @throws InterceptException (thrown sneakily) when a rule is violated
     */
    @SneakyThrows
    private void intercept(String sql) {
        if (sql == null) {
            return;
        }
        if (sql.indexOf(';') >= 0 && countStatements(sql) > 1) {
            throw new InterceptException("不允许SQL拼接!");
        }
        // Cheap pre-filter: the IN keyword can only occur where the letters "in" do, so inputs such
        // as "id\nIN(" or ")IN(" are not missed. Locale.ROOT keeps "IN" lower-casing to "in" even
        // when the JVM default locale is e.g. Turkish (where 'I' lower-cases to dotless 'ı').
        if (sql.toLowerCase(Locale.ROOT).contains("in") && largestInListSize(sql) > maxInListSize) {
            throw new InterceptException("IN列表不能超过" + maxInListSize + "个");
        }
    }

    /**
     * Counts the non-empty statements in {@code sql}, i.e. the maximal runs of tokens other than
     * {@code ;}. Leading, trailing and repeated semicolons therefore add no statements:
     * {@code "a;"} and {@code ";a;;"} count as 1, {@code "a;b"} and {@code "a;;b"} as 2.
     */
    private static int countStatements(String sql) {
        Lexer lexer = new H2StatementParser(sql).getLexer();
        int statements = 0;
        boolean insideStatement = false;
        Token token = lexer.token();
        while (token != Token.EOF) {
            if (token == Token.SEMI) {
                insideStatement = false;
            } else if (!insideStatement) {
                insideStatement = true;
                statements++;
            }
            lexer.nextToken();
            token = lexer.token();
        }
        return statements;
    }

    /**
     * Returns the number of elements in the largest multi-element {@code IN (...)} value list in
     * {@code sql}, or 0 if no IN list contains a top-level comma. Single-element lists are not
     * tracked, because they can never exceed a limit of at least 1.
     *
     * <p>Elements are counted as top-level commas + 1. Commas nested in inner parentheses (function
     * calls, row values) are ignored, and {@code IN (SELECT ...)} / {@code IN (WITH ...)}
     * subqueries are not treated as value lists. Unbalanced parentheses are tolerated.
     */
    private static int largestInListSize(String sql) {
        Lexer lexer = new H2StatementParser(sql).getLexer();
        // One frame per currently open parenthesis, innermost on top.
        Deque<ParenFrame> openParens = new ArrayDeque<>();
        int largest = 0;
        Token prev = null;
        Token token = lexer.token();
        while (token != Token.EOF) {
            ParenFrame innermost = openParens.peek();
            if (token == Token.LPAREN) {
                openParens.push(new ParenFrame(prev == Token.IN));
            } else if (token == Token.RPAREN) {
                // poll() instead of pop(): a stray ')' must not fail with an exception.
                openParens.poll();
            } else if (innermost != null && innermost.valueList) {
                if (prev == Token.LPAREN && (token == Token.SELECT || token == Token.WITH)) {
                    // IN (subquery): commas in its select list / FROM clause are not list elements.
                    innermost.valueList = false;
                } else if (token == Token.COMMA) {
                    innermost.commas++;
                    largest = Math.max(largest, innermost.commas + 1);
                }
            }
            prev = token;
            lexer.nextToken();
            token = lexer.token();
        }
        return largest;
    }

    /** Book-keeping for one open parenthesis while scanning for IN value lists. */
    private static final class ParenFrame {
        /** Whether this parenthesis directly follows {@code IN} and is not a subquery. */
        boolean valueList;
        /** Number of top-level commas seen so far inside this parenthesis. */
        int commas;

        ParenFrame(boolean valueList) {
            this.valueList = valueList;
        }
    }
}
